{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we begin, we will change a few settings to make the notebook look a bit prettier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "%%html\n",
    "<style> body {font-family: \"Calibri\", cursive, sans-serif;} </style>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# 01 - DeepSurvK Quickstart\n",
    "In this notebook, I will show DeepSurvK's basic functionality.\n",
    "\n",
    "Before going forward, I recommend you check the previous notebook,\n",
    "[\"Understanding DeepSurv\"](./00_understanding_deepsurv.ipynb). \n",
    "There, you will learn the working principles of the original DeepSurv \n",
    "algorithm. Furthermore, there are also some useful usage recommendations \n",
    "that also apply here. However, I will just mention them without going\n",
    "into the details.\n",
    "\n",
    "## Preliminaries\n",
    "\n",
    "Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "import deepsurvk\n",
    "from deepsurvk.datasets import load_whas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get data\n",
    "For convenience, DeepSurvK comes with DeepSurv's original datasets. \n",
    "This way, we can load sample data very easily (notice the import at the\n",
    "top)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, Y_train, E_train, = load_whas(partition='training', data_type='np')\n",
    "X_test, Y_test, E_test = load_whas(partition='testing', data_type='np')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These `training` and `testing` partitions correspond to the original\n",
    "partitions used in DeepSurv's paper. \n",
    "However, you could also load the complete dataset using\n",
    "`partition='complete'` and partition it as you wish (e.g., using\n",
    "sklearn's [`train_test_split`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html))"
   ]
  },
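  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of that alternative, the snippet below splits a complete\n",
    "dataset with `train_test_split`. The arrays are synthetic stand-ins (not\n",
    "the WHAS data), and the 80/20 ratio, the fixed seed, and stratifying on the\n",
    "event indicator are assumptions chosen for illustration:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Synthetic stand-ins for a complete dataset (hypothetical values).\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.normal(size=(100, 6))           # features\n",
    "Y = rng.exponential(10.0, size=100)     # survival times\n",
    "E = rng.integers(0, 2, size=100)        # event indicator (1 = event observed)\n",
    "\n",
    "# Split all three arrays consistently; stratifying on E keeps the\n",
    "# proportion of censored patients similar in both partitions.\n",
    "X_tr, X_te, Y_tr, Y_te, E_tr, E_te = train_test_split(\n",
    "    X, Y, E, test_size=0.2, stratify=E, random_state=0)\n",
    "\n",
    "print(X_tr.shape, X_te.shape)\n",
    "```"
   ]
  },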
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Calculate important parameters.\n",
    "n_patients_train = X_train.shape[0]\n",
    "n_features = X_train.shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-process data\n",
    "Data pre-processing is an important step. However, DeepSurvK leaves this\n",
    "to the user, since it depends very much on the data themselves.\n",
    "As mentioned in the [previous notebook]((./00_understanding_deepsurv.ipynb)), \n",
    "at the very least, I would recommend doing standardization and sorting:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Standardization\n",
    "X_scaler = StandardScaler().fit(X_train)\n",
    "X_train = X_scaler.transform(X_train)\n",
    "X_test = X_scaler.transform(X_test)\n",
    "\n",
    "Y_scaler = StandardScaler().fit(Y_train.reshape(-1, 1))\n",
    "Y_train = Y_scaler.transform(Y_train)\n",
    "Y_test = Y_scaler.transform(Y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "Y_train = Y_train.flatten()\n",
    "Y_test = Y_test.flatten()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Notice that if you read/have your data as a `pandas` DataFrame, you will\n",
    "> get an error when reshaping `Y_train` (see [issue #81](https://github.com/arturomoncadatorres/deepsurvk/issues/81)). \n",
    "> That is because a DataFrame doesn't have the `reshape` attribute.\n",
    ">\n",
    "> In such case, you need to do the reshaping as follows:\n",
    ">\n",
    "> ```\n",
    "> Y_scaler = StandardScaler().fit(Y_train.values.reshape(-1, 1))\n",
    "> ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorting\n",
    "sort_idx = np.argsort(Y_train)[::-1]\n",
    "X_train = X_train[sort_idx]\n",
    "Y_train = Y_train[sort_idx]\n",
    "E_train = E_train[sort_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Notice that if you read/have your data as a `pandas` DataFrame, you will\n",
    "> get an error when sorting (see [issue #82](https://github.com/arturomoncadatorres/deepsurvk/issues/82)). \n",
    "> That is because a DataFrame cannot be sorted like this.\n",
    ">\n",
    "> In such case, you need to do the sorting as follows:\n",
    ">\n",
    "> ```\n",
    "> X_train = X_train.values[sort_idx]\n",
    "> ...\n",
    "> ```\n",
    "\n",
    "## Create a DeepSurvK model\n",
    "When creating an instance of a DeepSurvK model, we can also define its \n",
    "parameters. The only mandatory parameters are `n_features` and `E`.\n",
    "If not defined, the rest of the parameters will use a default.\n",
    "This is, of course, far from optimal, since (hyper)parameter tuning\n",
    "has a *huge* impact on model performance. However, we will deal\n",
    "with that later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsk = deepsurvk.DeepSurvK(n_features=n_features, E=E_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since DeepSurvK is just a Keras model, we can take advantage of all the\n",
    "perks and tools that come with it. For example, we can get an overview\n",
    "of the model architecture very easily."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsk.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Callbacks\n",
    "As mentioned earlier, it is practical to use Early Stopping in the\n",
    "case of NaNs in loss values. Additionally, it is also a good idea\n",
    "to use the model that during the training phase yields the lowest loss\n",
    "(which isn't necessarily the one at the end of the training)\n",
    "\n",
    "Both of these practices can be achieved using [callbacks](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback).\n",
    "DeepSurvK provides a method to generate these two specific callbacks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "callbacks = deepsurvk.common_callbacks()\n",
    "print(callbacks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Needless to say that you can define your own callbacks as well, of course.\n",
    "\n",
    "## Model fitting\n",
    "After this, we are ready to actually fit our model (as any Keras model)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "epochs = 1000\n",
    "history = dsk.fit(X_train, Y_train, \n",
    "                  batch_size=n_patients_train,\n",
    "                  epochs=epochs, \n",
    "                  callbacks=callbacks,\n",
    "                  shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> In some cases, it has been reported that while fitting a model,\n",
    "> the [loss goes to a `NaN` very early](https://github.com/arturomoncadatorres/deepsurvk/issues/83),\n",
    "> making the training process unfeasible, even with the previously defined\n",
    "> callback. I haven't been able to replicate that issue consistently.\n",
    "> \n",
    "> However, this issue has also [been reported in the original DeepSurv](https://github.com/jaredleekatzman/DeepSurv/issues/14).\n",
    "> Apparently, a potentially good solution for this is to *not* \n",
    "> standardize your data during the pre-procesing, but rather\n",
    "> normalizing it (i.e., make sure that features are in the range 0-1).\n",
    "> However, remember that scaling is particularly sensitive to\n",
    "> outliers, so be careful!"
   ]
  },
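  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The outlier caveat is easy to see in a small NumPy sketch of min-max\n",
    "normalization (scikit-learn's `MinMaxScaler` performs the same computation).\n",
    "The helper name and the toy values are hypothetical:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def min_max_fit_transform(X_fit, X_apply):\n",
    "    # Learn per-feature minimum and range from X_fit, then rescale X_apply.\n",
    "    x_min = X_fit.min(axis=0)\n",
    "    x_range = X_fit.max(axis=0) - x_min\n",
    "    x_range[x_range == 0] = 1.0  # avoid dividing by zero for constant features\n",
    "    return (X_apply - x_min) / x_range\n",
    "\n",
    "X_toy = np.array([[1.0], [2.0], [3.0], [4.0]])\n",
    "scaled_toy = min_max_fit_transform(X_toy, X_toy)\n",
    "print(scaled_toy.ravel())\n",
    "\n",
    "# A single outlier squeezes the remaining values towards 0.\n",
    "X_outlier = np.array([[1.0], [2.0], [3.0], [4.0], [100.0]])\n",
    "scaled_outlier = min_max_fit_transform(X_outlier, X_outlier)\n",
    "print(scaled_outlier.ravel())\n",
    "```\n",
    "\n",
    "As with standardization, the minimum and range should be learned from the\n",
    "training data only and then applied to the test data."
   ]
  },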
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DeepSurvK provides a few wrappers to generate visualizations that are\n",
    "often required fast and easy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "deepsurvk.plot_loss(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model predictions\n",
    "Finally, we can generate predictions using our model.\n",
    "We can evaluate them using the c-index."
   ]
  },
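  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the c-index is the fraction of comparable patient pairs that\n",
    "the model orders correctly. The brute-force function below is a simplified\n",
    "sketch (hypothetical name, tied times ignored), not the implementation behind\n",
    "`deepsurvk.concordance_index`. It assumes that a higher score means a longer\n",
    "predicted survival, which appears to be the reason for the `np.exp(-...)`\n",
    "transform in the next cell:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def c_index_sketch(times, scores, events):\n",
    "    times = np.asarray(times, dtype=float)\n",
    "    scores = np.asarray(scores, dtype=float)\n",
    "    events = np.asarray(events, dtype=bool)\n",
    "    concordant, comparable = 0.0, 0\n",
    "    for i in range(len(times)):\n",
    "        if not events[i]:\n",
    "            continue  # a pair is comparable only if the earlier time is an observed event\n",
    "        for j in range(len(times)):\n",
    "            if times[i] < times[j]:\n",
    "                comparable += 1\n",
    "                if scores[i] < scores[j]:\n",
    "                    concordant += 1.0   # shorter survival got the lower score\n",
    "                elif scores[i] == scores[j]:\n",
    "                    concordant += 0.5   # tied scores count as half\n",
    "    return concordant / comparable\n",
    "\n",
    "# Toy example (hypothetical values): a perfectly ordered prediction.\n",
    "toy_times = [1.0, 2.0, 3.0, 4.0]\n",
    "toy_events = [1, 1, 0, 1]\n",
    "toy_scores = [0.1, 0.2, 0.3, 0.4]\n",
    "print(c_index_sketch(toy_times, toy_scores, toy_events))\n",
    "```\n",
    "\n",
    "A value of 0.5 corresponds to random ordering and 1.0 to a perfect ranking."
   ]
  },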
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_pred_train = np.exp(-dsk.predict(X_train))\n",
    "c_index_train = deepsurvk.concordance_index(Y_train, Y_pred_train, E_train)\n",
    "print(f\"c-index of training dataset = {c_index_train}\")\n",
    "\n",
    "Y_pred_test = np.exp(-dsk.predict(X_test))\n",
    "c_index_test = deepsurvk.concordance_index(Y_test, Y_pred_test, E_test)\n",
    "print(f\"c-index of testing dataset = {c_index_test}\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:percent"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
